/*
 * printf.c
 *
 * Copyright (C) 2018 Aleksandar Andrejevic <theflash@sdf.lonestar.org>
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU Affero General Public License as
 * published by the Free Software Foundation, either version 3 of the
 * License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU Affero General Public License for more details.
 *
 * You should have received a copy of the GNU Affero General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <stdio.h>
#include <string.h>
#include <stdlib.h>
#include <monolithium.h>
#include "io_priv.h"

/*
 * Bits accumulated in the per-directive `flags` word while parsing a
 * "%..." conversion specification.
 */
#define __CRT_PRINTF_FLAG_ALIGN_LEFT         0x001 /* '-' flag */
#define __CRT_PRINTF_FLAG_PLUS               0x002 /* '+' flag */
#define __CRT_PRINTF_FLAG_SPACE              0x004 /* ' ' flag */
#define __CRT_PRINTF_FLAG_EXTRA              0x008 /* '#' flag */
#define __CRT_PRINTF_FLAG_ZERO_PAD           0x010 /* '0' flag */
#define __CRT_PRINTF_FLAG_EXTERNAL_WIDTH     0x020 /* width given as '*' */
#define __CRT_PRINTF_FLAG_EXTERNAL_PRECISION 0x040 /* precision given as '*' */
#define __CRT_PRINTF_FLAG_DEFAULT_WIDTH      0x080 /* no width in the directive */
#define __CRT_PRINTF_FLAG_DEFAULT_PRECISION  0x100 /* no precision in the directive */

/*
 * Output sink for the formatter: either a FILE stream or a caller-supplied
 * character buffer. When `stream` is non-NULL output goes to the stream and
 * the other two fields are unused.
 */
typedef struct __crt_stream_or_string
{
    FILE *stream;   /* destination stream, or NULL for string output */
    char *string;   /* next free byte of the destination buffer */
    size_t size;    /* bytes still available, including room for the final NUL */
} __crt_stream_or_string_t;

/*
 * Emits a single character to the sink.
 *
 * Returns 1 when the character was written. Returns 0 when the stream write
 * failed, or when the string buffer has no room left; one byte is always kept
 * free for the terminating NUL.
 */
static int __crt_strputc(__crt_stream_or_string_t *str, char c)
{
    if (str->stream != NULL) return fputc_unlocked(c, str->stream) != EOF;
    if (str->size <= 1) return 0;

    str->size--;
    *(str->string)++ = c;
    return 1;
}

/*
 * Emits the NUL-terminated string `s` to the sink.
 *
 * Stream mode: returns strlen(s) on success, or 0 if the write failed.
 * String mode: copies as much of `s` as fits while always keeping one byte
 * free for the terminating NUL. It returns the number of characters actually
 * stored, which is 0 once the buffer is full.
 */
static int __crt_strputs(__crt_stream_or_string_t *str, const char *s)
{
    if (str->stream)
    {
        return fputs_unlocked(s, str->stream) != EOF ? (int)strlen(s) : 0;
    }

    if (str->size > 1)
    {
        size_t length = strlen(s);

        /*
         * Bounded copy of exactly `length` bytes. strncpy would zero-fill up to
         * size - 1 bytes, which overruns the caller's buffer when size is
         * effectively unbounded (vsprintf passes (size_t)-1).
         */
        if (length > str->size - 1) length = str->size - 1;
        memcpy(str->string, s, length);
        str->string += length;
        str->size -= length;
        return (int)length;
    }

    return 0;
}

static inline int __crt_vstrprintf(__crt_stream_or_string_t *str, const char *format, va_list ap)
{
    int ret = 0;
    if (str->stream) syscall_wait_mutex(str->stream->mutex, NO_TIMEOUT);

    const char *ptr;
    for (ptr = format; *ptr; ptr++)
    {
        if (*ptr == '%')
        {
            unsigned long flags = 0;
            unsigned int width = 0, precision = 0;

            ptr++;

            while (*ptr)
            {
                if (*ptr == '-') flags |= __CRT_PRINTF_FLAG_ALIGN_LEFT;
                else if (*ptr == '+') flags |= __CRT_PRINTF_FLAG_PLUS;
                else if (*ptr == ' ') flags |= __CRT_PRINTF_FLAG_SPACE;
                else if (*ptr == '#') flags |= __CRT_PRINTF_FLAG_EXTRA;
                else if (*ptr == '0') flags |= __CRT_PRINTF_FLAG_ZERO_PAD;
                else break;

                ptr++;
            }

            if (*ptr == '*') flags |= __CRT_PRINTF_FLAG_EXTERNAL_WIDTH;
            else if (isdigit(*ptr)) width = strtoul(ptr, (char**)&ptr, 10);
            else flags |= __CRT_PRINTF_FLAG_DEFAULT_WIDTH;

            if (*ptr == '.')
            {
                ptr++;
                if (*ptr == '*') flags |= __CRT_PRINTF_FLAG_EXTERNAL_PRECISION;
                else if (isdigit(*ptr)) precision = strtoul(ptr, (char**)&ptr, 10);
                else flags |= __CRT_PRINTF_FLAG_DEFAULT_PRECISION;
            }
            else
            {
                flags |= __CRT_PRINTF_FLAG_DEFAULT_PRECISION;
            }

            int variable_size = 0;

            if (*ptr == 'h')
            {
                variable_size--;

                if (*++ptr == 'h')
                {
                    variable_size--;
                    ptr++;
                }
            }
            else if (*ptr == 'l')
            {
                variable_size--;

                if (*++ptr == 'l')
                {
                    variable_size--;
                    ptr++;
                }
            }

            const char *data = NULL;
            char buffer[512];
            int radix = 10;

            switch (*ptr)
            {
            case 's':
                data = va_arg(ap, const char*);
                break;

            case 'c':
                data = buffer;
                buffer[0] = (char)va_arg(ap, int);
                buffer[1] = '\0';
                break;

            case 'd':
            case 'i':
                if (flags & __CRT_PRINTF_FLAG_DEFAULT_WIDTH) width = 0;
                if (flags & __CRT_PRINTF_FLAG_DEFAULT_PRECISION) precision = 1;
                data = buffer;

                switch (variable_size)
                {
                case -2: itoa((char)va_arg(ap, int), buffer, 10); break;
                case -1: itoa((short)va_arg(ap, int), buffer, 10); break;
                case  0: itoa(va_arg(ap, int), buffer, 10); break;
                case  1: ltoa(va_arg(ap, long), buffer, 10); break;
                case  2: lltoa(va_arg(ap, long long), buffer, 10); break;
                }
                break;

            case 'o':
                radix = 8;
                goto read_variable;
            case 'x':
            case 'X':
                radix = 16;
            case 'u':
            read_variable:
                if (flags & __CRT_PRINTF_FLAG_DEFAULT_WIDTH) width = 0;
                if (flags & __CRT_PRINTF_FLAG_DEFAULT_PRECISION) precision = 1;
                data = buffer;

                switch (variable_size)
                {
                case -2: uitoa((unsigned char)va_arg(ap, unsigned int), buffer, radix); break;
                case -1: uitoa((unsigned short)va_arg(ap, unsigned int), buffer, radix); break;
                case  0: uitoa(va_arg(ap, unsigned int), buffer, radix); break;
                case  1: ultoa(va_arg(ap, unsigned long), buffer, radix); break;
                case  2: ulltoa(va_arg(ap, unsigned long long), buffer, radix); break;
                }
                break;

            case '%':
                data = buffer;
                strcpy(buffer, "%");
                break;
            }

            if (flags & __CRT_PRINTF_FLAG_EXTERNAL_WIDTH) width = va_arg(ap, unsigned int);
            if (flags & __CRT_PRINTF_FLAG_EXTERNAL_PRECISION) precision = va_arg(ap, unsigned int);

            int length = strlen(data);
            char padding = (flags &__CRT_PRINTF_FLAG_ZERO_PAD) ? '0' : ' ';

            switch (*ptr)
            {
            case 's':
                if (!(flags & __CRT_PRINTF_FLAG_ALIGN_LEFT))
                {
                    while (length < width)
                    {
                        __crt_strputc(str, padding);
                        length++;
                    }
                }

                if (precision)
                {
                    while (*data && precision > 0)
                    {
                        __crt_strputc(str, *data++);
                        precision--;
                    }
                }
                else
                {
                    __crt_strputs(str, data ? data : "(null)");
                }
                break;

            case 'd':
            case 'i':
            case 'o':
            case 'u':
            case 'x':
            case 'X':
                if (flags & (__CRT_PRINTF_FLAG_PLUS | __CRT_PRINTF_FLAG_SPACE))
                {
                    if (*data == '-')
                    {
                        __crt_strputc(str, '-');
                        data++;
                    }
                    else
                    {
                        __crt_strputc(str, (flags & __CRT_PRINTF_FLAG_PLUS) ? '+' : ' ');
                        length++;
                    }
                }

                if (!(flags & __CRT_PRINTF_FLAG_ALIGN_LEFT) && width > precision)
                {
                    while (length < width - precision)
                    {
                        __crt_strputc(str, padding);
                        length++;
                    }

                    while (length < width)
                    {
                        __crt_strputc(str, '0');
                        length++;
                    }
                }
                else
                {
                    while (length < precision)
                    {
                        __crt_strputc(str, '0');
                        length++;
                    }
                }

                __crt_strputs(str, data);
                break;
            }

            while (length < width)
            {
                __crt_strputc(str, padding);
                length++;
            }
        }
        else
        {
            ret += __crt_strputc(str, *ptr);
        }
    }

    if (str->string) *str->string = '\0';
    if (str->stream) syscall_release_mutex(str->stream->mutex);
    return ret;
}

/*
 * Formats into `string`, storing at most `size` bytes including the
 * terminating NUL. When `size` is 0 nothing at all is stored, and `string`
 * may then be NULL.
 */
int vsnprintf(char *string, size_t size, const char *format, va_list ap)
{
    /*
     * Treat a zero-sized or absent buffer as "count only" so that no byte,
     * not even the final NUL terminator, is ever stored through it.
     */
    if (size == 0 || string == NULL)
    {
        string = NULL;
        size = 0;
    }

    __crt_stream_or_string_t str = { .stream = NULL, .string = string, .size = size };
    return __crt_vstrprintf(&str, format, ap);
}

/* Formats into `str` with no bound on the number of bytes stored. */
int vsprintf(char *str, const char *format, va_list ap)
{
    /* (size_t)-1 is the largest size_t, i.e. "no limit" for vsnprintf. */
    const size_t unbounded = (size_t)-1;

    return vsnprintf(str, unbounded, format, ap);
}

/* Formats to standard output; equivalent to vfprintf(stdout, ...). */
int vprintf(const char *format, va_list ap)
{
    FILE *out = stdout;

    return vfprintf(out, format, ap);
}

int vfprintf(FILE *stream, const char *format, va_list ap)
{
    __crt_stream_or_string_t str = { .stream = stream, .string = NULL, .size = 0 };
    return __crt_vstrprintf(&str, format, ap);
}

/* Variadic front end for vfprintf. */
int fprintf(FILE *stream, const char *format, ...)
{
    va_list args;
    int written;

    va_start(args, format);
    written = vfprintf(stream, format, args);
    va_end(args);

    return written;
}

/* Variadic front end for vsnprintf. */
int snprintf(char *str, size_t size, const char *format, ...)
{
    va_list args;
    int written;

    va_start(args, format);
    written = vsnprintf(str, size, format, args);
    va_end(args);

    return written;
}

/* Variadic front end for vsprintf. */
int sprintf(char *str, const char *format, ...)
{
    va_list args;
    int written;

    va_start(args, format);
    written = vsprintf(str, format, args);
    va_end(args);

    return written;
}

/* Variadic front end for vprintf (standard output). */
int printf(const char *format, ...)
{
    va_list args;
    int written;

    va_start(args, format);
    written = vprintf(format, args);
    va_end(args);

    return written;
}
